import numpy as np
import dezero.functions as F
from dezero.datasets import Spiral
from dezero import DataLoader

# ======= test DataLoader
batch_size = 10
max_epoch = 1
# Training data
train_set = Spiral(train=True)
train_loader = DataLoader(train_set, batch_size)
# Test data: the loader must wrap test_set (not train_set) so the batches
# drawn from test_loader really come from the held-out split.
test_set = Spiral(train=False)
test_loader = DataLoader(test_set, batch_size, shuffle=False)

for epoch in range(max_epoch):
    # Print the shapes of the first mini-batch from each loader, training
    # loader first, then test loader; `break` stops after one batch each.
    for _loader in (train_loader, test_loader):
        for x, t in _loader:
            print(x.shape, t.shape)
            break

# ======= test Accuracy
# 3-class classification: each row of y holds the scores for one sample.
y = np.array(
    [
        [0.2, 0.8, 0],
        [0.1, 0.9, 0],
        [0.8, 0.1, 0.1],
    ]
)
# t holds, for each sample, the index of its correct class.
t = np.array([2, 1, 0])
acc = F.accuracy(y, t)
print(f'评估精度为 {acc.data}')
